/*
 * ProgramStreamParser.cpp
 *
 *  Created: 2013-07-07
 *   Author: terry
 */

#include "stdafx.h"
#include "ProgramStreamParser.h"
#include <assert.h>

/**
 * Construct a parser with no sink attached.
 *
 * Both working buffers are given an initial size up front via ensure():
 * 500 KiB for raw, not-yet-parsed input and 700 KiB for reassembled video
 * payload (presumably a capacity reservation - confirm against the buffer class).
 */
ProgramStreamParser::ProgramStreamParser()
    : m_pSink()
    , m_lastType()
{
    const int kInputBufferBytes = 500 * 1024;
    const int kCompoundBufferBytes = 700 * 1024;

    m_buffer.ensure(kInputBufferBytes);
    m_compoundBuffer.ensure(kCompoundBufferBytes);
}

/**
 * Destructor. No explicit cleanup is performed: the sink pointer is only
 * stored by setSink() and is not owned here, and the buffer members are
 * presumably responsible for their own storage.
 */
ProgramStreamParser::~ProgramStreamParser()
{
}

/**
 * Append raw program-stream bytes and dispatch every complete packet.
 *
 * A packet is considered complete once the start code of the following
 * packet (00 00 01 followed by a stream id, see isStreamID()) has been seen;
 * the bytes from one start code up to the next are handed to onNewPacket().
 *
 * Bytes that cannot belong to any packet are discarded:
 *  - data in front of the first start code, and
 *  - when no start code is present at all, everything except the last
 *    3 bytes (which may be the beginning of a start code split across calls).
 * This keeps the input buffer from growing without bound on data that
 * contains no recognizable start code.
 *
 * @param buffer  raw bytes to append
 * @param length  number of bytes in @p buffer
 * @return true if at least one packet was dispatched during this call
 */
bool ProgramStreamParser::inputData(const uint8_t* buffer, size_t length)
{
    if (length == 0)
    {
        return false;
    }

    m_buffer.write(buffer, length);

    bool found = false;
    while (true)
    {
        uint8_t* data = m_buffer.getReadPtr();
        size_t size = m_buffer.readable();

        StreamPacket firstPacket;
        if (!findPacket(data, size, 0, firstPacket))
        {
            // No "00 00 01 <id>" anywhere. Only the trailing 3 bytes can still
            // turn into a start code once more data arrives; the rest is junk.
            if (size > 3)
            {
                m_buffer.skip(size - 3);
            }
            break;
        }

        // Start searching after "00 00 01" of the first packet; its stream id
        // byte is >= 0xBA, so it can never begin another start code.
        StreamPacket secondPacket;
        if (!findPacket(data, size, firstPacket.length + 3, secondPacket))
        {
            // First packet is still incomplete: keep it, drop what precedes it.
            if (firstPacket.length > 0)
            {
                m_buffer.skip(firstPacket.length);
            }
            break;
        }

        // findPacket() stores the start-code offset in `length`; convert the
        // first packet's offset into its size (distance to the next start code).
        firstPacket.length = secondPacket.length - firstPacket.length;
        onNewPacket(firstPacket);

        // Consume everything up to (not including) the second start code.
        m_buffer.skip(secondPacket.length);
        found = true;
    }

    return found;
}

/**
 * Attach (or, with a null pointer, detach) the receiver of reassembled NAL data.
 * The pointer is stored as-is; the parser does not take ownership.
 *
 * @param pSink  sink that writePacket() forwards to; may be null
 */
void ProgramStreamParser::setSink(NaluAnalyzerSink* pSink)
{
    this->m_pSink = pSink;
}

void ProgramStreamParser::clear()
{
    m_buffer.clear();
}

/**
 * Route one complete program-stream packet by its stream id.
 *
 * Video PES packets (ids 0xE0-0xEF) have their payload appended to the
 * compound buffer. Any other packet - pack/system headers and other
 * 0xBA-0xBF ids, audio 0xC0-0xDF, and everything else - marks the end of the
 * video data collected so far, which is then flushed to the sink.
 *
 * @param packet  packet starting at its start code; packet.length is its size
 */
void ProgramStreamParser::onNewPacket(StreamPacket& packet)
{
    const bool isVideoStream = (packet.type >= 0xE0) && (packet.type <= 0xEF);
    if (isVideoStream)
    {
        onPESPacket(packet);
        return;
    }

    flushCompoundBuffer();
}


/**
 * Append the payload of one video PES packet to the compound buffer.
 *
 * Assumed header layout (MPEG-2 PES):
 *   bytes 0-2  start code 00 00 01
 *   byte  3    stream id
 *   bytes 4-5  PES_packet_length, big-endian
 *   bytes 6-7  flag bytes
 *   byte  8    PES_header_data_length (count of optional header bytes)
 * followed by the optional header bytes and then the payload.
 *
 * Malformed packets are dropped: too short for the fixed header, optional
 * header running past the packet end, or PES_packet_length larger than the
 * packet.
 *
 * @param packet  PES packet; packet.length is its total size in bytes,
 *                start code included
 */
void ProgramStreamParser::onPESPacket(StreamPacket& packet)
{
    // start code (3) + stream id (1) + PES_packet_length (2) + flags (2)
    // + PES_header_data_length (1)
    const size_t kFixedHeaderSize = 9;

    if (packet.length < 9)
    {
        // invalid PES: fixed header incomplete
        return;
    }

    const size_t total = static_cast<size_t>(packet.length);
    const size_t pktLength = (static_cast<size_t>(packet.data[4]) << 8) | packet.data[5];
    const size_t headerSize = kFixedHeaderSize + packet.data[8];

    // NOTE(review): PES_packet_length counts the bytes after the length field,
    // so a complete packet satisfies pktLength + 6 <= total. This check is
    // looser - confirm whether truncated packets should also be rejected.
    if (pktLength > total)
    {
        return;
    }

    // A corrupt PES_header_data_length could point past the packet. Without
    // this check the payload size would underflow into a huge out-of-bounds read.
    if (headerSize > total)
    {
        return;
    }

    // size_t (not 16-bit) length: the start-code-delimited packet may exceed 64 KiB.
    m_compoundBuffer.write(packet.data + headerSize, total - headerSize);
}

/**
 * Forward a reassembled packet to the attached sink.
 * If no sink has been set, the packet is silently dropped.
 *
 * @param packet  packet to deliver
 */
void ProgramStreamParser::writePacket(NaluPacket& packet)
{
    if (m_pSink)
    {
        m_pSink->writePacket(packet);
    }
}

/**
 * Emit the accumulated video payload as one NaluPacket and reset the buffer.
 *
 * The data is forwarded only if it begins with an Annex-B start code
 * (00 00 01 or 00 00 00 01) followed by at least one byte. That byte is
 * reported as the packet type. Anything else is discarded. The compound
 * buffer is cleared in every case.
 *
 * pkt.data points into m_compoundBuffer, which is cleared as soon as
 * writePacket() returns, so the sink should not keep the pointer.
 */
void ProgramStreamParser::flushCompoundBuffer()
{
    if (m_compoundBuffer.empty())
    {
        return;
    }

    NaluPacket pkt;
    pkt.data = m_compoundBuffer.getReadPtr();
    pkt.length = m_compoundBuffer.readable();
    pkt.prefix = 0;
    pkt.type = NaluPacket::NALU_NULL;

    // Guard length > prefix: data consisting only of a start code has no
    // header byte, and reading data[prefix] would run past the readable region.
    if (searchPrefix(pkt.data, pkt.length, pkt.prefix)
        && (static_cast<size_t>(pkt.length) > static_cast<size_t>(pkt.prefix)))
    {
        pkt.type = pkt.data[pkt.prefix];
        writePacket(pkt);
    }

    m_compoundBuffer.clear();
}

/**
 * Find the first Annex-B start code (00 00 01 or 00 00 00 01) at or after
 * @p start whose following header byte is also inside the buffer.
 *
 * A start code whose header byte has not arrived yet is not reported. This
 * lets the caller wait for more data instead of reading past the buffer.
 * Previously a 4-byte code in the last four bytes did read past it.
 *
 * On success:
 *   packet.data   -> first byte of the start code
 *   packet.length =  offset of the start code within @p buffer
 *   packet.prefix =  start code length (3 or 4)
 *   packet.type   =  byte immediately after the start code
 *
 * @param buffer  bytes to scan
 * @param length  number of valid bytes in @p buffer
 * @param start   offset to begin scanning at
 * @param packet  filled in on success, untouched otherwise
 * @return true if a start code was found
 */
bool ProgramStreamParser::findPacket(uint8_t* buffer, size_t length, size_t start, NaluPacket& packet)
{
    if ((length < 3) || (start > length) || ((length - start) < 3))
    {
        return false;
    }

    uint8_t* p = buffer;
    // i + 3 < length keeps p[i+3] readable inside the loop.
    for (size_t i = start; (i + 3) < length; ++i)
    {
        if ((p[i] != 0) || (p[i+1] != 0))
        {
            continue;
        }

        int prefix = 0;
        if (p[i+2] == 1)
        {
            prefix = 3;     // 00 00 01
        }
        else if ((p[i+2] == 0) && (p[i+3] == 1))
        {
            prefix = 4;     // 00 00 00 01
        }
        else
        {
            continue;
        }

        // The header byte follows the start code; it must be in the buffer.
        if ((i + static_cast<size_t>(prefix)) >= length)
        {
            continue;
        }

        packet.data = p + i;
        packet.length = i;
        packet.prefix = prefix;
        packet.type = p[i + prefix];
        return true;
    }
    return false;
}

/**
 * Find the first program-stream start code - 00 00 01 followed by a byte
 * accepted by isStreamID() - at or after @p start.
 *
 * On success:
 *   packet.data   -> first byte of the start code
 *   packet.length =  offset of the start code within @p buffer
 *   packet.type   =  the stream id byte
 *
 * @param buffer  bytes to scan
 * @param length  number of valid bytes in @p buffer
 * @param start   offset to begin scanning at
 * @param packet  filled in on success, untouched otherwise
 * @return true if a start code was found
 */
bool ProgramStreamParser::findPacket(uint8_t* buffer, size_t length, size_t start, StreamPacket& packet)
{
    // A start code plus stream id needs 4 bytes.
    if (length < 4 || (length - start) < 4)
    {
        return false;
    }

    const size_t lastCandidate = length - 4;   // last offset where 4 bytes fit
    for (size_t pos = start; pos <= lastCandidate; ++pos)
    {
        const uint8_t* code = buffer + pos;

        const bool isStartCode = (code[0] == 0) && (code[1] == 0) && (code[2] == 1);
        if (!isStartCode || !isStreamID(code[3]))
        {
            continue;
        }

        packet.data = buffer + pos;
        packet.length = pos;
        packet.type = code[3];
        return true;
    }

    return false;
}

/**
 * Whether the byte after 00 00 01 is treated as a program-stream packet code.
 *
 * Only 0xBA-0xFF is accepted. Start codes carried inside a payload whose
 * following byte is below 0xBA are therefore not taken as packet boundaries.
 *
 * @param ch  byte following a 00 00 01 start code
 * @return true if @p ch >= 0xBA
 */
bool ProgramStreamParser::isStreamID(uint8_t ch)
{
    const uint8_t kLowestStreamCode = 0xBA;
    return !(ch < kLowestStreamCode);
}

/**
 * Check whether @p buffer starts with an Annex-B start code.
 *
 * Only offset 0 is examined; no scanning is done.
 *
 * @param buffer  bytes to inspect
 * @param length  number of valid bytes in @p buffer
 * @param prefix  set to 3 for 00 00 01 or 4 for 00 00 00 01; left
 *                unchanged when no start code is present
 * @return true if a start code begins at buffer[0]
 */
bool ProgramStreamParser::searchPrefix(uint8_t* buffer, size_t length, int& prefix)
{
    if (length < 3 || buffer[0] != 0 || buffer[1] != 0)
    {
        return false;
    }

    if (buffer[2] == 1)
    {
        prefix = 3;
        return true;
    }

    const bool fourByteCode = (buffer[2] == 0) && (length >= 4) && (buffer[3] == 1);
    if (fourByteCode)
    {
        prefix = 4;
        return true;
    }

    return false;
}

